#include "model_state.h"

namespace triton { namespace backend { namespace ascend {

// Index of the first model input tensor in shape/gear vectors.
constexpr uint32_t FIRST_INPUT_INDEX = 0;
// Index of the batch dimension within a tensor shape.
constexpr uint32_t TENSOR_BATCH_INDEX = 0;

// Factory: allocate a ModelState for the given Triton model, initialize the
// MxBase runtime, and parse/auto-complete the model configuration.
// Returns nullptr on success or a TRITONSERVER_Error the caller owns.
TRITONSERVER_Error*
ModelState::Create(TRITONBACKEND_Model* triton_model, ModelState** state)
{
    try {
        *state = new(std::nothrow) ModelState(triton_model);
    }
    catch (const BackendModelException& ex) {
        RETURN_ERROR_IF_TRUE(
            ex.err_ == nullptr, TRITONSERVER_ERROR_INTERNAL,
            std::string("unexpected nullptr in BackendModelException"));
        RETURN_IF_ERROR(ex.err_);
    }
    // new(std::nothrow) reports allocation failure with nullptr instead of
    // throwing; guard before *state is dereferenced below.
    RETURN_ERROR_IF_TRUE(
        *state == nullptr, TRITONSERVER_ERROR_INTERNAL,
        std::string("failed to allocate ModelState for model"));
    // Initialize the MxBase runtime (device init).
    APP_ERROR ascend_ret = MxBase::MxInit();
    RETURN_ERROR_IF_TRUE((ascend_ret != APP_ERR_OK), TRITONSERVER_ERROR_INTERNAL,
        std::string("failed to init device. ret: " + std::to_string(ascend_ret))
    );
    // Parse the model configuration, auto-completing it first if Triton
    // requested auto-complete for this model.
    bool auto_complete_config = false;
    RETURN_IF_ERROR(
        TRITONBACKEND_ModelAutoCompleteConfig(triton_model, &auto_complete_config)
    );
    if (auto_complete_config) {
        RETURN_IF_ERROR((*state)->AutoCompleteConfig());
        RETURN_IF_ERROR((*state)->SetModelConfig());
    }
    RETURN_IF_ERROR((*state)->ParseParameters());

    return nullptr;  // success
}

// 默认不支持自动补全config,如果需要支持该功能,则需要使用mxBase接口进行查询来构建config
// Auto-complete of the model config is not supported by this backend; log a
// warning and succeed so loading proceeds with the user-supplied config. If
// support is needed later, query model properties via the mxBase API here.
TRITONSERVER_Error*
ModelState::AutoCompleteConfig()
{
    const std::string warn_msg =
        std::string("skipping model configuration auto-complete for '") +
        Name() + "': not supported for ascend backend";
    LOG_MESSAGE(TRITONSERVER_LOG_WARN, warn_msg.c_str());
    return nullptr;
}

// 解析config文件,目前暂时没有需要额外输入的信息,如果需要则后续补充,接口预留
// Reserved hook for backend-specific config parameters. Currently no extra
// parameters are defined; extend the body below when some are introduced.
TRITONSERVER_Error*
ModelState::ParseParameters()
{
    triton::common::TritonJson::Value parameters;
    if (model_config_.Find("parameters", &parameters)) {
        // Intentionally empty: the "parameters" section exists but no
        // backend-specific entries are consumed yet.
    }
    return nullptr;
}

// Resolve the on-disk model artifact, load it onto the given device, and (on
// first load) query its dynamic-shape information.
// @param artifact_name  model filename from the config; empty selects "model.om"
// @param device         Ascend device id to load onto
// @param model_path     out: full path of the resolved model file
// @param ascend_model   out: loaded MxBase model handle
TRITONSERVER_Error*
ModelState::LoadModel(const std::string& artifact_name, const int32_t device, std::string* model_path,
    std::shared_ptr<MxBase::Model>* ascend_model)
{
    std::string cc_model_filename = artifact_name;

    // If artifact_name is not set, fall back to the default model filename.
    if (cc_model_filename.empty()) {
        cc_model_filename = "model.om";
    }
    // Build the model path and verify the file exists.
    *model_path = JoinPath({RepositoryPath(), std::to_string(Version()), cc_model_filename});
    {
        bool exists;
        RETURN_IF_ERROR(FileExists(*model_path, &exists));
        RETURN_ERROR_IF_FALSE(
            exists, TRITONSERVER_ERROR_UNAVAILABLE,
            std::string("unable to find '") + *model_path + "' for model instance '" + Name() + "'"
        );
    }

    try {
        ascend_model->reset(new(std::nothrow) MxBase::Model(*model_path, device));
    }
    catch (const std::exception& ex) {
        return TRITONSERVER_ErrorNew(
            TRITONSERVER_ERROR_INTERNAL,
            std::string("failed to load model '" + Name() + "': " + ex.what()).c_str()
        );
    }
    // new(std::nothrow) yields nullptr on allocation failure; without this
    // guard the nullptr would be dereferenced in GetModelDynamicInfo below.
    RETURN_ERROR_IF_TRUE(
        *ascend_model == nullptr, TRITONSERVER_ERROR_INTERNAL,
        std::string("failed to allocate model '" + Name() + "'"));

    // Query dynamic info only once, guarded against concurrent instance loads.
    std::unique_lock<std::mutex> init_lock(model_mutex_);
    if (!is_initialized_) {
        RETURN_IF_ERROR(GetModelDynamicInfo(ascend_model));
        is_initialized_ = true;
    }
    init_lock.unlock();

    return nullptr;
}

// Query the loaded model's dynamic-gear information: input format, dynamic
// type, dynamic shape gears, and (for dynamic-HW models) the feature-map
// stride derived from inference probes.
TRITONSERVER_Error*
ModelState::GetModelDynamicInfo(std::shared_ptr<MxBase::Model>* ascend_model)
{
    std::vector<std::vector<uint64_t>> input_gear = (*ascend_model)->GetDynamicGearInfo();

    const std::map<MxBase::VisionDataFormat, DataFormat> mxbase_format_to_backend {
        {MxBase::VisionDataFormat::NCHW, DataFormat::NCHW},
        {MxBase::VisionDataFormat::NHWC, DataFormat::NHWC},
        // {MxBase::VisionDataFormat::ND, DataFormat::ND} // ND not available in the current MxBase version; add when it is
    };
    auto input_format = (*ascend_model)->GetInputFormat();

    // Look up explicitly: operator[] would silently insert a value-initialized
    // DataFormat for any format missing from the map (e.g. a future ND).
    auto format_it = mxbase_format_to_backend.find(input_format);
    RETURN_ERROR_IF_TRUE(
        format_it == mxbase_format_to_backend.end(), TRITONSERVER_ERROR_INTERNAL,
        std::string("unsupported input data format for model '" + Name() + "'"));
    input_format_ = format_it->second;

    // Collect the shape of every input tensor, then derive the dynamic type
    // and the per-gear shapes from the gear info.
    std::vector<std::vector<int64_t>> input_shape;
    uint32_t num = (*ascend_model)->GetInputTensorNum();
    for (uint32_t i = 0; i < num; i++) {
        input_shape.emplace_back((*ascend_model)->GetInputTensorShape(i));
    }
    RETURN_IF_ERROR(GetModelDynamicType(input_shape, input_gear));

    RETURN_IF_ERROR(GetModelDynamicShape(input_gear));

    RETURN_IF_ERROR(GetModelHWStride(ascend_model));
    return nullptr;
}

// Classify the model's dynamic type from its gear info and input shapes.
// - no gear info + a -1 dimension        => DYNAMIC_SHAPE (shape range model)
// - 1 gear value: ND with static batch   => DYNAMIC_DIMS, otherwise DYNAMIC_BATCH
// - 2 gear values: ND                    => DYNAMIC_DIMS, otherwise DYNAMIC_HW
// - anything else                        => DYNAMIC_DIMS
TRITONSERVER_Error*
ModelState::GetModelDynamicType(std::vector<std::vector<int64_t>>& input_shape,
    std::vector<std::vector<uint64_t>>& input_gear)
{
    if (input_gear.empty()) {
        // No gear info: a -1 in any input dimension marks a dynamic-shape
        // (shape range) model; otherwise the default type is kept.
        if (HasDynamicShape(input_shape)) {
            dynamic_info_.dynamic_type = DynamicType::DYNAMIC_SHAPE;
        }
        return nullptr;
    }

    const size_t gear_rank = input_gear[FIRST_INPUT_INDEX].size();
    if (gear_rank == 1) {
        // ND format with a non-dynamic batch dimension means dynamic dims;
        // otherwise a single gear value is the dynamic batch size.
        const bool nd_with_static_batch =
            (input_format_ == DataFormat::ND &&
             input_shape[FIRST_INPUT_INDEX][TENSOR_BATCH_INDEX] != -1);
        dynamic_info_.dynamic_type = nd_with_static_batch
            ? DynamicType::DYNAMIC_DIMS
            : DynamicType::DYNAMIC_BATCH;
    } else if (gear_rank == 2) {
        dynamic_info_.dynamic_type = (input_format_ == DataFormat::ND)
            ? DynamicType::DYNAMIC_DIMS
            : DynamicType::DYNAMIC_HW;
    } else {
        dynamic_info_.dynamic_type = DynamicType::DYNAMIC_DIMS;
    }
    return nullptr;
}

// Record the gear values for the detected dynamic type: batch sizes for
// DYNAMIC_BATCH, (H, W) pairs for DYNAMIC_HW, full dim tuples for
// DYNAMIC_DIMS. STATIC_BATCH and DYNAMIC_SHAPE only log the type.
TRITONSERVER_Error* ModelState::GetModelDynamicShape(std::vector<std::vector<uint64_t>>& input_gear)
{
    switch (dynamic_info_.dynamic_type)
    {
    case DynamicType::STATIC_BATCH:
        LOG_MESSAGE(TRITONSERVER_LOG_INFO,
            std::string("Model '" + Name() + "' dynamic type is: STATIC_BATCH").c_str());
        break;
    case DynamicType::DYNAMIC_BATCH:
        LOG_MESSAGE(TRITONSERVER_LOG_INFO,
            std::string("Model '" + Name() + "' dynamic type is: DYNAMIC_BATCH").c_str());
        // Each gear holds a single value: the supported batch size.
        for (const auto& item : input_gear) {
            dynamic_info_.dynamic_batch.emplace_back(static_cast<uint32_t>(item[0]));
        }
        break;
    case DynamicType::DYNAMIC_HW:
        LOG_MESSAGE(TRITONSERVER_LOG_INFO,
            std::string("Model '" + Name() + "' dynamic type is: DYNAMIC_HW").c_str());
        // Each gear holds a (height, width) pair.
        for (const auto& item : input_gear) {
            std::vector<uint32_t> dynamic_size(item.begin(), item.end());
            dynamic_info_.dynamic_size.emplace_back(dynamic_size);
        }
        break;
    case DynamicType::DYNAMIC_DIMS:
        LOG_MESSAGE(TRITONSERVER_LOG_INFO,
            std::string("Model '" + Name() + "' dynamic type is: DYNAMIC_DIMS").c_str());
        // Each gear holds one complete dimension tuple.
        for (const auto& item : input_gear) {
            std::vector<uint32_t> dynamic_dims(item.begin(), item.end());
            dynamic_info_.dynamic_dims.emplace_back(dynamic_dims);
        }
        break;
    case DynamicType::DYNAMIC_SHAPE:
        LOG_MESSAGE(TRITONSERVER_LOG_INFO,
            std::string("Model '" + Name() + "' dynamic type is: DYNAMIC_SHAPE").c_str());
        break;
    default:
        return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INVALID_ARG,
            std::string("'" + Name() + "' Undefined dynamic type.").c_str());
    }
    return nullptr;
}

// For dynamic-HW models with a supported output format, probe the model with
// its smallest and largest (H, W) gears to capture the corresponding output
// feature-map sizes, from which the H/W stride mapping can be derived.
TRITONSERVER_Error*
ModelState::GetModelHWStride(std::shared_ptr<MxBase::Model>* ascend_model)
{
    // The H/W mapping/crop is only meaningful for dynamic-resolution models.
    if (dynamic_info_.dynamic_type == DynamicType::DYNAMIC_HW && CheckDataFormat(output_format_)) {
        LOG_MESSAGE(TRITONSERVER_LOG_INFO, std::string("Get model feature map stride").c_str());

        // Copy before sorting so the member's gear order is left untouched.
        std::vector<std::vector<uint32_t>> dynamic_size = dynamic_info_.dynamic_size;
        if (dynamic_size.size() < 2) {
            LOG_MESSAGE(TRITONSERVER_LOG_INFO,
                std::string("model input gear is less than 2, cannot get feature map stride").c_str());
            return nullptr;
        }
        // After sorting, front() is the smallest gear and back() the largest.
        std::sort(dynamic_size.begin(), dynamic_size.end());
        model_stride_info_.min_input_gear.emplace_back(dynamic_size.front()[0]);
        model_stride_info_.min_input_gear.emplace_back(dynamic_size.front()[1]);

        RETURN_IF_ERROR(GetModelInferOutputShape(ascend_model, model_stride_info_.min_input_gear,
            model_stride_info_.min_output_gear));

        model_stride_info_.max_input_gear.emplace_back(dynamic_size.back()[0]);
        model_stride_info_.max_input_gear.emplace_back(dynamic_size.back()[1]);

        RETURN_IF_ERROR(GetModelInferOutputShape(ascend_model, model_stride_info_.max_input_gear,
            model_stride_info_.max_output_gear));
    }
    return nullptr;
}

// Run one probe inference at the given input (H, W) gear and collect the
// resulting output (H, W) sizes into output_gear.
// @param ascend_model  loaded model handle
// @param input_gear    in: [height, width] to patch into every input shape
// @param output_gear   out: [height, width] appended per output tensor
TRITONSERVER_Error*
ModelState::GetModelInferOutputShape(std::shared_ptr<MxBase::Model>* ascend_model,
    std::vector<uint32_t>& input_gear, std::vector<uint32_t>& output_gear)
{
    LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, std::string("Get Output shape start .").c_str());
    std::vector<MxBase::Tensor> input_tensors;
    // RAII ownership of the probe buffers: previously they were tracked as
    // raw pointers cast to int64_t and leaked on the early error return below.
    std::vector<std::unique_ptr<void, decltype(&free)>> buffers;
    uint32_t nums = (*ascend_model)->GetInputTensorNum();
    for (uint32_t i = 0; i < nums; i++) {
        auto shape = (*ascend_model)->GetInputTensorShape(i);
        // Patch the dynamic H/W dimensions according to the input format.
        if (input_format_ == DataFormat::NCHW) {
            shape[NCHW_HEIGHT_INDEX] = (int64_t)input_gear[0];
            shape[NCHW_WIDTH_INDEX] = (int64_t)input_gear[1];
        } else if (input_format_ == DataFormat::NHWC) {
            shape[NHWC_HEIGHT_INDEX] = (int64_t)input_gear[0];
            shape[NHWC_WIDTH_INDEX] = (int64_t)input_gear[1];
        }
        std::vector<uint32_t> input_shape(shape.begin(), shape.end());
        auto data_type = (*ascend_model)->GetInputTensorDataType(i);
        MxBase::Tensor temp_tensor(input_shape, data_type, -1);
        size_t byte_size = temp_tensor.GetByteSize();
        void* buffer = malloc(byte_size);
        RETURN_ERROR_IF_TRUE(buffer == nullptr, TRITONSERVER_ERROR_INTERNAL,
            std::string("failed to malloc input tensor buffer."));
        buffers.emplace_back(buffer, &free);
        input_tensors.emplace_back(MxBase::Tensor(buffer, input_shape, data_type, -1));
    }
    auto output_tensors = (*ascend_model)->Infer(input_tensors);
    if (output_tensors.empty()) {
        // buffers are released by their deleters on this path as well.
        return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL,
            (std::string("failed to get model '") + Name() + "' feature map stride").c_str());
    }
    for (size_t i = 0; i < output_tensors.size(); i++) {
        // Guard against the model producing more outputs than the config
        // declared formats for.
        if (i >= output_format_.size()) {
            break;
        }
        auto output_shape = output_tensors[i].GetShape();
        DataFormat output_format = output_format_[i];
        if (output_format == DataFormat::NCHW) {
            output_gear.emplace_back(output_shape[NCHW_HEIGHT_INDEX]);
            output_gear.emplace_back(output_shape[NCHW_WIDTH_INDEX]);
        } else if (output_format == DataFormat::NHWC) {
            output_gear.emplace_back(output_shape[NHWC_HEIGHT_INDEX]);
            output_gear.emplace_back(output_shape[NHWC_WIDTH_INDEX]);
        }
    }
    LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, std::string("Get Output shape ok .").c_str());
    return nullptr;
}


// Construct the ModelState and record the data format declared for each
// output in the model config. Unknown format strings fall back to ND
// (FORMAT_NONE) instead of silently default-inserting into the map.
ModelState::ModelState(TRITONBACKEND_Model* triton_model) : BackendModel(triton_model)
{
    const std::map<std::string, DataFormat> triton_format_to_backend {
        {"FORMAT_NCHW", DataFormat::NCHW},
        {"FORMAT_NHWC", DataFormat::NHWC},
        {"FORMAT_NONE", DataFormat::ND}
    };
    // Read the "format" field of every entry in the config's "output" array.
    triton::common::TritonJson::Value ios;
    THROW_IF_BACKEND_MODEL_ERROR(ModelConfig().MemberAsArray("output", &ios));
    for (size_t i = 0; i < ios.ArraySize(); i++) {
        triton::common::TritonJson::Value io;
        THROW_IF_BACKEND_MODEL_ERROR(ios.IndexAsObject(i, &io));

        std::string format;
        THROW_IF_BACKEND_MODEL_ERROR(io.MemberAsString("format", &format));
        LOG_MESSAGE(TRITONSERVER_LOG_INFO,
            (std::string("Model '") + Name() + "' num :" + std::to_string(i) + " output format is: " + format).c_str());
        // find() avoids operator[]'s default-insertion for unrecognized
        // format strings; treat anything unknown like FORMAT_NONE.
        auto it = triton_format_to_backend.find(format);
        output_format_.emplace_back(
            it != triton_format_to_backend.end() ? it->second : DataFormat::ND);
    }
}
}}};